#ifdef USE_XETLA_XE_HPC

#include "hgemm_xehpc.hpp"

namespace torch_ipex::xpu::xetla::xehpc {

#include "hgemm_policy_xehpc.h"

// NOTE(review): compiled out. If enabled, this would build a GemmPolicyT
// traits array with one entry per policy, expanded from the
// HGEMM_ENUMERATE_POLICIES_COMMA / HGEMM_POLICY_TRAITS macros (presumably
// provided by hgemm_policy_xehpc.h). Nothing in this file reads it.
#if 0
static GemmPolicyT hgemm_policy_traits[HGEMM_NUM_POLICIES] = {
    HGEMM_ENUMERATE_POLICIES_COMMA(HGEMM_POLICY_TRAITS)};
#endif

// Tuned HGEMM policy overrides for B-row-major GEMMs.
//
// hgemm_find_policy_id() searches this table through
// hgemm_find_policy_id_(m, n, k, table). Each key is a GemmShapeT
// initialiser, presumably {m, n, k} to match that argument order (TODO
// confirm against GemmShapeT). GemmShapeT is also the hasher (third template
// argument). Each value is an hgemm_policy enumerator stored as int.
//
// Keep every shape unique. When an initializer list repeats a key,
// std::unordered_map keeps only one of the entries, and the standard leaves
// unspecified which one (LWG 2844). Repeated keys are therefore dead at best
// and implementation-defined at worst.
static std::unordered_map<GemmShapeT, int, GemmShapeT>
    hgemm_b_row_special_table = {
        {{1, 4096, 16384}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{1, 16384, 4096}, hgemm_policy::_8x512_8x16x16_1_true_},
        {{1, 16384, 3072}, hgemm_policy::_8x512_8x16x16_1_true_},
        {{1, 32064, 3072}, hgemm_policy::_8x512_8x16x16_1_true_},
        {{1, 3072, 3072}, hgemm_policy::_32x64_8x16x16_2_true_},
        {{1, 3072, 8192}, hgemm_policy::_32x64_8x16x16_2_true_},
        {{1, 4096, 4096}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{4, 4096, 16384}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{4, 16384, 4096}, hgemm_policy::_8x512_8x16x16_1_true_},
        {{4, 16384, 3072}, hgemm_policy::_8x512_8x16x16_1_true_},
        {{4, 32064, 3072}, hgemm_policy::_8x512_8x16x16_1_true_},
        {{4, 3072, 3072}, hgemm_policy::_32x64_8x16x16_2_true_},
        {{4, 3072, 8192}, hgemm_policy::_32x64_8x16x16_2_true_},
        {{4, 4096, 4096}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{32, 4096, 16384}, hgemm_policy::_32x64_8x16x16_2_true_},
        {{32, 16384, 4096}, hgemm_policy::_128x512_64x32x16_1_true_},
        {{32, 4096, 4096}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{33, 4096, 16384}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{33, 16384, 4096}, hgemm_policy::_128x512_64x32x16_1_true_},
        {{33, 4096, 4096}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{64, 4096, 16384}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{64, 16384, 4096}, hgemm_policy::_128x256_64x16x16_1_true_},
        {{64, 4096, 4096}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{65, 16384, 4096}, hgemm_policy::_128x256_64x16x16_1_true_},
        {{128, 16384, 4096}, hgemm_policy::_256x256_64x32x16_1_true_},
        {{130, 4096, 16384}, hgemm_policy::_128x128_32x32x32_2_true_},
        {{130, 16384, 4096}, hgemm_policy::_256x256_64x32x16_1_true_},
        {{130, 4096, 4096}, hgemm_policy::_128x128_32x32x32_2_true_},
        {{256, 4096, 16384}, hgemm_policy::_128x128_32x32x32_2_true_},
        {{256, 16384, 4096}, hgemm_policy::_256x256_64x32x16_1_true_},
        {{256, 4096, 4096}, hgemm_policy::_128x128_32x32x32_2_true_},
        {{512, 4096, 16384}, hgemm_policy::_128x256_32x32x16_1_true_},
        {{512, 16384, 4096}, hgemm_policy::_256x256_64x32x16_1_true_},
        {{512, 4096, 4096}, hgemm_policy::_128x128_32x32x32_2_true_},
        {{513, 4096, 16384}, hgemm_policy::_256x256_64x32x16_1_true_},
        {{513, 16384, 4096}, hgemm_policy::_256x256_64x32x16_1_true_},
        {{513, 4096, 4096}, hgemm_policy::_128x512_64x32x16_1_true_},
        {{1024, 4096, 16384}, hgemm_policy::_256x256_64x32x16_1_true_},
        {{1024, 4096, 4096}, hgemm_policy::_256x256_64x32x16_1_true_},
        {{2016, 4096, 16384}, hgemm_policy::_256x256_64x32x16_1_true_},
        {{2016, 4096, 4096}, hgemm_policy::_256x256_64x32x16_1_true_},
        {{1, 50400, 4096}, hgemm_policy::_128x512_64x32x16_1_true_},
        {{1, 50272, 4096}, hgemm_policy::_128x512_64x32x16_1_true_},
        {{4, 50400, 4096}, hgemm_policy::_128x512_64x32x16_1_true_},
        {{4, 50272, 4096}, hgemm_policy::_128x512_64x32x16_1_true_},
        {{1, 250880, 4096}, hgemm_policy::_32x64_8x16x16_2_true_},
        {{4, 250880, 4096}, hgemm_policy::_32x64_8x16x16_2_true_},
        {{1, 11008, 4096}, hgemm_policy::_16x256_8x16x16_1_true_},
        {{4, 11008, 4096}, hgemm_policy::_16x256_8x16x16_1_true_},
        {{32, 11008, 4096}, hgemm_policy::_128x256_32x32x16_1_true_},
        {{64, 11008, 4096}, hgemm_policy::_64x256_64x16x16_2_true_},
        {{128, 11008, 4096}, hgemm_policy::_128x256_32x32x16_1_true_},
        {{256, 11008, 4096}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{512, 11008, 4096}, hgemm_policy::_256x256_64x32x16_1_true_},
        {{1, 32000, 4096}, hgemm_policy::_8x512_8x16x16_1_true_},
        {{4, 32000, 4096}, hgemm_policy::_8x512_8x16x16_1_true_},
        {{1, 13824, 5120}, hgemm_policy::_256x256_64x32x16_1_true_},
        {{1, 5120, 5120}, hgemm_policy::_8x128_8x16x16_2_true_},
        {{4, 13824, 5120}, hgemm_policy::_128x256_64x16x16_1_true_},
        {{4, 5120, 5120}, hgemm_policy::_8x128_8x16x16_2_true_},
        {{32, 13824, 5120}, hgemm_policy::_128x256_64x16x16_1_true_},
        {{32, 5120, 5120}, hgemm_policy::_32x128_32x16x16_4_true_},
        {{64, 13824, 5120}, hgemm_policy::_128x256_32x32x16_1_true_},
        {{64, 5120, 5120}, hgemm_policy::_128x128_16x32x64_1_true_},
        {{128, 13824, 5120}, hgemm_policy::_128x256_32x32x16_1_true_},
        {{128, 5120, 5120}, hgemm_policy::_128x128_16x32x64_1_true_},
        {{256, 13824, 5120}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{256, 5120, 5120}, hgemm_policy::_128x256_32x32x16_1_true_},
        {{512, 13824, 5120}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{512, 5120, 5120}, hgemm_policy::_256x256_64x32x16_1_true_},
        {{1024, 5120, 5120}, hgemm_policy::_256x256_64x32x16_1_true_},
        {{2016, 5120, 5120}, hgemm_policy::_256x256_64x32x16_1_true_},
        {{1, 32000, 5120}, hgemm_policy::_8x512_8x16x16_1_true_},
        {{4, 32000, 5120}, hgemm_policy::_8x512_8x16x16_1_true_},
        {{1, 7168, 14336}, hgemm_policy::_8x256_8x16x16_2_true_},
        {{1, 1792, 14336}, hgemm_policy::_16x64_16x16x16_8_true_},
        {{4, 7168, 14336}, hgemm_policy::_8x256_8x16x16_2_true_},
        {{4, 1792, 14336}, hgemm_policy::_16x64_16x16x16_8_true_},
        {{32, 7168, 14336}, hgemm_policy::_16x256_16x16x16_2_true_},
        {{32, 1792, 14336}, hgemm_policy::_16x64_16x16x16_8_true_},
        {{33, 7168, 14336}, hgemm_policy::_32x256_32x16x16_2_true_},
        {{33, 1792, 14336}, hgemm_policy::_32x64_32x16x16_8_true_},
        {{64, 7168, 14336}, hgemm_policy::_128x128_16x32x64_1_true_},
        {{64, 1792, 14336}, hgemm_policy::_32x64_32x16x16_8_true_},
        {{65, 7168, 14336}, hgemm_policy::_128x128_16x32x64_1_true_},
        {{65, 1792, 14336}, hgemm_policy::_32x64_32x16x16_8_true_},
        {{1, 14336, 7168}, hgemm_policy::_128x512_64x32x16_1_true_},
        {{1, 14336, 1792}, hgemm_policy::_128x256_32x32x16_1_true_},
        {{4, 14336, 7168}, hgemm_policy::_128x256_64x16x16_1_true_},
        {{4, 14336, 1792}, hgemm_policy::_16x256_8x16x16_1_true_},
        {{32, 14336, 7168}, hgemm_policy::_256x256_64x32x16_1_true_},
        {{32, 14336, 1792}, hgemm_policy::_128x256_32x32x16_1_true_},
        {{33, 14336, 7168}, hgemm_policy::_128x256_64x16x16_1_true_},
        {{33, 14336, 1792}, hgemm_policy::_128x256_32x32x16_1_true_},
        {{64, 14336, 7168}, hgemm_policy::_128x256_32x32x16_1_true_},
        {{64, 14336, 1792}, hgemm_policy::_256x256_32x64x16_1_true_},
        {{65, 14336, 7168}, hgemm_policy::_128x256_32x32x16_1_true_},
        {{65, 14336, 1792}, hgemm_policy::_256x256_32x64x16_1_true_},
        {{1, 250880, 1792}, hgemm_policy::_32x64_8x16x16_2_true_},
        {{1, 2048, 8192}, hgemm_policy::_8x64_8x16x32_8_true_},
        {{1, 3584, 7168}, hgemm_policy::_32x64_8x16x16_2_true_},
        {{1, 3584, 8192}, hgemm_policy::_32x64_8x16x16_2_true_},
        {{1, 7168, 3584}, hgemm_policy::_128x128_16x32x64_1_true_},
        {{1, 7168, 8192}, hgemm_policy::_8x256_8x16x16_2_true_},
        {{1, 8192, 1024}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{1, 8192, 2048}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{1, 8192, 3584}, hgemm_policy::_8x256_8x16x16_2_true_},
        {{1, 8192, 7168}, hgemm_policy::_8x256_8x16x16_2_true_},
        {{1, 256, 8192}, hgemm_policy::_16x64_16x16x16_8_true_},
        {{1, 32000, 1024}, hgemm_policy::_8x512_8x16x16_1_true_},
        {{1, 32000, 2048}, hgemm_policy::_8x512_8x16x16_1_true_},
        {{4, 2048, 8192}, hgemm_policy::_8x64_8x16x32_8_true_},
        {{4, 3584, 7168}, hgemm_policy::_32x64_8x16x16_2_true_},
        {{4, 3584, 8192}, hgemm_policy::_32x64_8x16x16_2_true_},
        {{4, 7168, 3584}, hgemm_policy::_128x128_16x32x64_1_true_},
        {{4, 7168, 8192}, hgemm_policy::_8x256_8x16x16_2_true_},
        {{4, 8192, 1024}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{4, 8192, 2048}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{4, 8192, 3584}, hgemm_policy::_8x256_8x16x16_2_true_},
        {{4, 8192, 7168}, hgemm_policy::_8x256_8x16x16_2_true_},
        {{4, 256, 8192}, hgemm_policy::_16x64_16x16x16_8_true_},
        {{4, 32000, 1024}, hgemm_policy::_8x512_8x16x16_1_true_},
        {{4, 32000, 2048}, hgemm_policy::_8x512_8x16x16_1_true_},
        {{1024, 2048, 8192}, hgemm_policy::_128x256_32x32x16_1_true_},
        {{1024, 7168, 8192}, hgemm_policy::_256x256_64x32x16_1_true_},
        {{1024, 8192, 1024}, hgemm_policy::_256x256_64x32x16_1_true_},
        {{1024, 8192, 2048}, hgemm_policy::_256x256_64x32x16_1_true_},
        {{1024, 8192, 3584}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1024, 8192, 7168}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1024, 256, 8192}, hgemm_policy::_32x128_32x16x16_4_true_},
        // NOTE(review): this shape used to be listed a second time with the
        // conflicting value _16x64_16x16x16_8_true_. Only one entry survives
        // map construction. This first value is the one libstdc++/libc++
        // retained, so keeping it preserves observed behavior. Re-benchmark to
        // confirm it is the faster choice.
        {{1, 2048, 4096}, hgemm_policy::_8x64_8x16x32_8_true_},
        {{1, 4096, 2048}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{1, 4096, 8192}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{1, 8192, 4096}, hgemm_policy::_8x256_8x16x16_2_true_},
        {{4, 2048, 4096}, hgemm_policy::_16x64_16x16x16_8_true_},
        {{4, 4096, 2048}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{4, 4096, 8192}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{4, 8192, 4096}, hgemm_policy::_128x256_32x32x16_1_true_},
        {{32, 2048, 4096}, hgemm_policy::_16x64_16x16x16_8_true_},
        {{32, 4096, 2048}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{32, 4096, 8192}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{32, 8192, 4096}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{1024, 2048, 4096}, hgemm_policy::_128x128_32x32x32_2_true_},
        {{1024, 4096, 2048}, hgemm_policy::_256x256_64x32x16_1_true_},
        {{1024, 4096, 8192}, hgemm_policy::_256x256_64x32x16_1_true_},
        {{1024, 8192, 4096}, hgemm_policy::_256x256_64x32x16_1_true_},
        {{1, 2560, 5120}, hgemm_policy::_32x64_32x16x16_8_true_},
        {{1, 2560, 8192}, hgemm_policy::_32x64_32x16x16_8_true_},
        {{1, 5120, 6912}, hgemm_policy::_128x128_16x32x64_1_true_},
        {{1, 5120, 2560}, hgemm_policy::_8x128_8x16x16_2_true_},
        {{1, 6912, 5120}, hgemm_policy::_128x128_16x32x64_1_true_},
        {{4, 2560, 5120}, hgemm_policy::_32x64_32x16x16_8_true_},
        {{4, 2560, 8192}, hgemm_policy::_32x64_32x16x16_8_true_},
        {{4, 5120, 6912}, hgemm_policy::_128x128_16x32x64_1_true_},
        {{4, 5120, 2560}, hgemm_policy::_8x128_8x16x16_2_true_},
        {{4, 6912, 5120}, hgemm_policy::_128x128_16x32x64_1_true_},
        {{32, 2560, 5120}, hgemm_policy::_32x64_32x16x16_8_true_},
        {{32, 5120, 6912}, hgemm_policy::_32x128_32x16x16_4_true_},
        {{32, 5120, 2560}, hgemm_policy::_32x128_32x16x16_4_true_},
        {{32, 6912, 5120}, hgemm_policy::_128x128_16x32x64_1_true_},
        {{33, 2560, 5120}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{33, 5120, 6912}, hgemm_policy::_128x128_16x32x64_1_true_},
        {{33, 5120, 2560}, hgemm_policy::_128x128_16x32x64_1_true_},
        {{33, 6912, 5120}, hgemm_policy::_128x128_16x32x64_1_true_},
        {{1024, 2560, 5120}, hgemm_policy::_256x256_64x32x16_1_true_},
        {{1024, 5120, 6912}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1024, 5120, 2560}, hgemm_policy::_256x256_64x32x16_1_true_},
        {{1024, 6912, 5120}, hgemm_policy::_256x256_64x32x16_1_true_},
        {{1, 32000, 8192}, hgemm_policy::_8x512_8x16x16_1_true_},
        {{4, 32000, 8192}, hgemm_policy::_8x512_8x16x16_1_true_},
        {{32, 7168, 8192}, hgemm_policy::_128x128_16x32x64_1_true_},
        {{32, 8192, 7168}, hgemm_policy::_128x256_32x32x16_1_true_},
        {{32, 2048, 8192}, hgemm_policy::_16x64_16x16x16_8_true_},
        {{32, 8192, 2048}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{1, 5504, 4096}, hgemm_policy::_64x128_64x16x16_4_true_},
        {{1, 4096, 5504}, hgemm_policy::_128x128_32x32x32_2_true_},
        {{4, 5504, 4096}, hgemm_policy::_8x128_8x16x16_2_true_},
        {{4, 4096, 5504}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{32, 5504, 4096}, hgemm_policy::_32x128_32x16x16_4_true_},
        {{32, 4096, 5504}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{1024, 5504, 4096}, hgemm_policy::_256x256_64x32x16_1_true_},
        {{1024, 4096, 5504}, hgemm_policy::_256x256_64x32x16_1_true_},
        {{1, 50272, 7168}, hgemm_policy::_128x512_64x32x16_1_true_},
        {{4, 50272, 7168}, hgemm_policy::_128x512_64x32x16_1_true_},
        {{32, 3584, 7168}, hgemm_policy::_32x64_8x16x16_2_true_},
        {{32, 7168, 3584}, hgemm_policy::_128x128_16x32x64_1_true_},
        {{1024, 7168, 14336}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1024, 3584, 7168}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1024, 3584, 8192}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1024, 7168, 3584}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1, 5120, 13696}, hgemm_policy::_32x128_8x16x32_1_true_},
        {{1, 13696, 5120}, hgemm_policy::_256x256_32x64x16_1_true_},
        {{4, 5120, 13696}, hgemm_policy::_8x128_8x16x16_2_true_},
        {{4, 13696, 5120}, hgemm_policy::_256x256_32x64x16_1_true_},
        {{16, 5120, 13696}, hgemm_policy::_64x128_64x16x16_4_true_},
        {{16, 13696, 5120}, hgemm_policy::_128x256_64x16x16_1_true_},
        {{32, 5120, 13696}, hgemm_policy::_32x128_32x16x16_4_true_},
        {{32, 13696, 5120}, hgemm_policy::_128x256_64x16x16_1_true_},
        {{1024, 5120, 13696}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1, 125696, 5120}, hgemm_policy::_256x256_32x64x16_1_true_},
        {{1, 5120, 125696}, hgemm_policy::_32x128_8x16x32_1_true_},
        {{4, 125696, 5120}, hgemm_policy::_256x256_32x64x16_1_true_},
        {{4, 5120, 125696}, hgemm_policy::_32x128_8x16x32_1_true_},
        {{32, 4608, 4096}, hgemm_policy::_32x128_32x16x16_4_true_},
        {{32, 4096, 4608}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{32, 4096, 13696}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{32, 13696, 4096}, hgemm_policy::_128x256_64x16x16_1_true_},
        {{1024, 4608, 4096}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1024, 4096, 4608}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1024, 4096, 13696}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1, 65024, 4096}, hgemm_policy::_256x256_32x64x16_1_true_},
        {{1, 4096, 65024}, hgemm_policy::_128x128_16x32x64_1_true_},
        {{1, 4608, 4096}, hgemm_policy::_8x32_8x16x16_4_true_},
        {{1, 4096, 4608}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{1, 13696, 4096}, hgemm_policy::_256x256_32x64x16_1_true_},
        {{1, 4096, 13696}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{4, 65024, 4096}, hgemm_policy::_256x256_32x64x16_1_true_},
        {{4, 4096, 65024}, hgemm_policy::_128x128_16x32x64_1_true_},
        {{4, 4608, 4096}, hgemm_policy::_32x128_32x16x16_4_true_},
        {{4, 4096, 4608}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{4, 13696, 4096}, hgemm_policy::_128x256_64x16x16_1_true_},
        {{4, 4096, 13696}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{1023, 3072, 3072}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1023, 3072, 8192}, hgemm_policy::_256x256_32x64x32_1_true_},
};

// Tuned HGEMM policy overrides for the QKV GEMM with B in row-major layout.
//
// hgemm_qkv_find_policy_id() searches this table through
// hgemm_find_policy_id_(m, n, k, table). Each key is a GemmShapeT
// initialiser, presumably {m, n, k} to match that argument order (TODO
// confirm against GemmShapeT). GemmShapeT is also the hasher (third template
// argument). Each value is an hgemm_policy enumerator stored as int.
//
// Keep every shape unique. When an initializer list repeats a key,
// std::unordered_map keeps only one of the entries, and the standard leaves
// unspecified which one (LWG 2844). Eight exact repeats of earlier entries
// were removed from this table; each had the same value as the entry that
// remains.
static std::unordered_map<GemmShapeT, int, GemmShapeT>
    hgemm_qkv_b_row_special_table = {
        {{1, 3072, 3072}, hgemm_policy::_32x64_8x16x16_2_true_},
        {{1, 4096, 16384}, hgemm_policy::_128x256_64x16x16_1_true_},
        {{1, 16384, 4096}, hgemm_policy::_256x256_32x64x16_1_true_},
        {{1, 4096, 4096}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{4, 4096, 16384}, hgemm_policy::_128x256_64x16x16_1_true_},
        {{4, 16384, 4096}, hgemm_policy::_256x256_32x64x16_1_true_},
        {{4, 3072, 3072}, hgemm_policy::_32x64_8x16x16_2_true_},
        {{4, 4096, 4096}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{32, 4096, 16384}, hgemm_policy::_128x256_64x16x16_1_true_},
        {{32, 16384, 4096}, hgemm_policy::_256x256_32x64x16_1_true_},
        {{32, 4096, 4096}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{33, 4096, 16384}, hgemm_policy::_128x256_64x16x16_1_true_},
        {{33, 16384, 4096}, hgemm_policy::_256x256_32x64x16_1_true_},
        {{33, 4096, 4096}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{64, 4096, 16384}, hgemm_policy::_128x256_64x16x16_1_true_},
        {{64, 16384, 4096}, hgemm_policy::_256x256_32x64x16_1_true_},
        {{64, 4096, 4096}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{65, 16384, 4096}, hgemm_policy::_256x256_32x64x16_1_true_},
        {{128, 16384, 4096}, hgemm_policy::_128x256_64x16x16_1_true_},
        {{130, 4096, 16384}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{130, 16384, 4096}, hgemm_policy::_256x256_32x64x16_1_true_},
        {{130, 4096, 4096}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{256, 4096, 16384}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{256, 16384, 4096}, hgemm_policy::_256x256_64x32x16_1_true_},
        {{256, 4096, 4096}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{512, 4096, 16384}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{512, 16384, 4096}, hgemm_policy::_256x256_32x64x16_1_true_},
        {{512, 4096, 4096}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{513, 4096, 16384}, hgemm_policy::_128x512_64x32x16_1_true_},
        {{513, 16384, 4096}, hgemm_policy::_128x512_64x32x16_1_true_},
        {{513, 4096, 4096}, hgemm_policy::_128x512_64x32x16_1_true_},
        {{1024, 4096, 16384}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1024, 16384, 4096}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1024, 4096, 4096}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1028, 4096, 16384}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1028, 16384, 4096}, hgemm_policy::_128x512_64x32x16_1_true_},
        {{1028, 4096, 4096}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{2016, 4096, 16384}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{2016, 16384, 4096}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{2016, 4096, 4096}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1, 50400, 4096}, hgemm_policy::_16x256_8x16x16_1_true_},
        {{1, 50272, 4096}, hgemm_policy::_16x256_8x16x16_1_true_},
        {{4, 50400, 4096}, hgemm_policy::_16x256_8x16x16_1_true_},
        {{4, 50272, 4096}, hgemm_policy::_16x256_8x16x16_1_true_},
        {{1, 11008, 4096}, hgemm_policy::_32x64_8x16x16_2_true_},
        {{4, 11008, 4096}, hgemm_policy::_32x64_8x16x16_2_true_},
        {{32, 11008, 4096}, hgemm_policy::_32x64_8x16x16_2_true_},
        {{64, 11008, 4096}, hgemm_policy::_128x256_32x32x16_1_true_},
        {{128, 11008, 4096}, hgemm_policy::_128x256_32x32x16_1_true_},
        {{256, 11008, 4096}, hgemm_policy::_256x256_64x32x16_1_true_},
        {{512, 11008, 4096}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1024, 11008, 4096}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{2016, 11008, 4096}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1, 32000, 4096}, hgemm_policy::_8x512_8x16x16_1_true_},
        {{4, 32000, 4096}, hgemm_policy::_8x512_8x16x16_1_true_},
        {{1, 13824, 5120}, hgemm_policy::_128x256_32x32x16_1_true_},
        {{1, 5120, 5120}, hgemm_policy::_256x256_32x64x16_1_true_},
        {{4, 13824, 5120}, hgemm_policy::_16x256_8x16x16_1_true_},
        {{4, 5120, 5120}, hgemm_policy::_256x256_32x64x16_1_true_},
        {{32, 13824, 5120}, hgemm_policy::_128x256_32x32x16_1_true_},
        {{32, 5120, 5120}, hgemm_policy::_256x256_32x64x16_1_true_},
        {{64, 13824, 5120}, hgemm_policy::_128x256_32x32x16_1_true_},
        {{64, 5120, 5120}, hgemm_policy::_256x256_32x64x16_1_true_},
        {{128, 13824, 5120}, hgemm_policy::_128x256_32x32x16_1_true_},
        {{128, 5120, 5120}, hgemm_policy::_256x256_32x64x16_1_true_},
        {{256, 13824, 5120}, hgemm_policy::_256x256_32x64x16_1_true_},
        {{256, 5120, 5120}, hgemm_policy::_256x256_64x32x16_1_true_},
        {{512, 13824, 5120}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{512, 5120, 5120}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1024, 13824, 5120}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1024, 5120, 5120}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{2016, 13824, 5120}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{2016, 5120, 5120}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1, 32000, 5120}, hgemm_policy::_8x512_8x16x16_1_true_},
        {{4, 32000, 5120}, hgemm_policy::_8x512_8x16x16_1_true_},
        {{1, 7168, 14336}, hgemm_policy::_64x512_64x16x16_1_true_},
        {{1, 1792, 14336}, hgemm_policy::_8x128_8x16x16_2_true_},
        {{4, 7168, 14336}, hgemm_policy::_64x512_64x16x16_1_true_},
        {{4, 1792, 14336}, hgemm_policy::_8x128_8x16x16_2_true_},
        {{32, 7168, 14336}, hgemm_policy::_64x512_64x16x16_1_true_},
        {{32, 1792, 14336}, hgemm_policy::_64x128_64x16x16_4_true_},
        {{33, 7168, 14336}, hgemm_policy::_64x512_64x16x16_1_true_},
        {{33, 1792, 14336}, hgemm_policy::_64x128_64x16x16_4_true_},
        {{64, 7168, 14336}, hgemm_policy::_64x512_64x16x16_1_true_},
        {{64, 1792, 14336}, hgemm_policy::_128x128_16x32x64_1_true_},
        {{65, 7168, 14336}, hgemm_policy::_128x512_64x32x16_1_true_},
        {{65, 1792, 14336}, hgemm_policy::_128x128_16x32x64_1_true_},
        {{1, 14336, 7168}, hgemm_policy::_16x256_8x16x16_1_true_},
        {{1, 14336, 1792}, hgemm_policy::_16x256_8x16x16_1_true_},
        {{4, 14336, 7168}, hgemm_policy::_16x256_8x16x16_1_true_},
        {{4, 14336, 1792}, hgemm_policy::_16x256_8x16x16_1_true_},
        {{32, 14336, 7168}, hgemm_policy::_256x256_32x64x16_1_true_},
        {{32, 14336, 1792}, hgemm_policy::_128x256_32x32x16_1_true_},
        {{33, 14336, 7168}, hgemm_policy::_256x256_32x64x16_1_true_},
        {{33, 14336, 1792}, hgemm_policy::_128x256_32x32x16_1_true_},
        {{64, 14336, 7168}, hgemm_policy::_256x256_32x64x16_1_true_},
        {{64, 14336, 1792}, hgemm_policy::_128x256_32x32x16_1_true_},
        {{65, 14336, 7168}, hgemm_policy::_128x256_32x32x16_1_true_},
        {{65, 14336, 1792}, hgemm_policy::_256x256_32x64x16_1_true_},
        {{1, 250880, 1792}, hgemm_policy::_32x64_8x16x16_2_true_},
        {{1, 2048, 8192}, hgemm_policy::_8x128_8x16x16_2_true_},
        {{1, 3584, 7168}, hgemm_policy::_16x256_8x16x16_1_true_},
        {{1, 7168, 3584}, hgemm_policy::_64x512_64x16x16_1_true_},
        {{1, 7168, 8192}, hgemm_policy::_64x512_64x16x16_1_true_},
        {{1, 8192, 2048}, hgemm_policy::_128x256_32x32x16_1_true_},
        {{1, 8192, 7168}, hgemm_policy::_16x256_8x16x16_1_true_},
        {{1, 64, 8192}, hgemm_policy::_8x32_8x16x16_8_true_},
        {{1, 128, 8192}, hgemm_policy::_8x32_8x16x16_8_true_},
        {{1, 256, 8192}, hgemm_policy::_8x32_8x16x16_8_true_},
        {{1, 32000, 2048}, hgemm_policy::_8x512_8x16x16_1_true_},
        {{4, 2048, 8192}, hgemm_policy::_8x128_8x16x16_2_true_},
        {{4, 3584, 7168}, hgemm_policy::_16x256_8x16x16_1_true_},
        {{4, 7168, 3584}, hgemm_policy::_64x512_64x16x16_1_true_},
        {{4, 7168, 8192}, hgemm_policy::_64x512_64x16x16_1_true_},
        {{4, 8192, 2048}, hgemm_policy::_16x256_8x16x16_1_true_},
        {{4, 8192, 7168}, hgemm_policy::_16x256_8x16x16_1_true_},
        {{4, 64, 8192}, hgemm_policy::_16x64_16x16x16_8_true_},
        {{4, 128, 8192}, hgemm_policy::_16x64_16x16x16_8_true_},
        {{4, 256, 8192}, hgemm_policy::_16x64_16x16x16_8_true_},
        {{4, 32000, 2048}, hgemm_policy::_8x512_8x16x16_1_true_},
        {{8, 128, 8192}, hgemm_policy::_8x32_8x16x16_8_true_},
        {{32, 128, 8192}, hgemm_policy::_32x64_32x16x16_8_true_},
        {{1024, 2048, 8192}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1024, 7168, 8192}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1024, 8192, 2048}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1024, 8192, 7168}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1024, 64, 8192}, hgemm_policy::_128x128_32x32x32_2_true_},
        {{1024, 128, 8192}, hgemm_policy::_128x128_32x32x32_2_true_},
        {{1024, 256, 8192}, hgemm_policy::_128x128_32x32x32_2_true_},
        {{8192, 128, 8192}, hgemm_policy::_128x128_32x32x32_2_true_},
        {{1, 2048, 4096}, hgemm_policy::_128x128_16x32x64_1_true_},
        {{1, 4096, 2048}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{1, 4096, 8192}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{1, 8192, 4096}, hgemm_policy::_32x64_8x16x16_2_true_},
        {{4, 2048, 4096}, hgemm_policy::_128x128_16x32x64_1_true_},
        {{4, 4096, 2048}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{4, 4096, 8192}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{4, 8192, 4096}, hgemm_policy::_16x256_8x16x16_1_true_},
        {{32, 2048, 4096}, hgemm_policy::_128x128_16x32x64_1_true_},
        {{32, 4096, 2048}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{32, 4096, 8192}, hgemm_policy::_128x256_64x16x16_1_true_},
        {{32, 8192, 4096}, hgemm_policy::_128x256_32x32x16_1_true_},
        {{1024, 2048, 4096}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1024, 4096, 2048}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1024, 4096, 8192}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1024, 8192, 4096}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1, 2560, 5120}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{1, 5120, 6912}, hgemm_policy::_128x512_64x32x16_1_true_},
        {{1, 5120, 2560}, hgemm_policy::_256x256_32x64x16_1_true_},
        {{1, 6912, 5120}, hgemm_policy::_64x512_64x16x16_1_true_},
        {{4, 2560, 5120}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{4, 5120, 6912}, hgemm_policy::_128x512_64x32x16_1_true_},
        {{4, 5120, 2560}, hgemm_policy::_256x256_32x64x16_1_true_},
        {{4, 6912, 5120}, hgemm_policy::_64x512_64x16x16_1_true_},
        {{32, 2560, 5120}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{32, 5120, 6912}, hgemm_policy::_128x256_64x16x16_1_true_},
        {{32, 5120, 2560}, hgemm_policy::_128x256_64x16x16_1_true_},
        {{32, 6912, 5120}, hgemm_policy::_64x512_64x16x16_1_true_},
        {{33, 2560, 5120}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{33, 5120, 6912}, hgemm_policy::_128x256_64x16x16_1_true_},
        {{33, 5120, 2560}, hgemm_policy::_128x256_64x16x16_1_true_},
        {{33, 6912, 5120}, hgemm_policy::_64x512_64x16x16_1_true_},
        {{1024, 2560, 5120}, hgemm_policy::_256x256_64x32x16_1_true_},
        {{1024, 5120, 6912}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1024, 5120, 2560}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1024, 6912, 5120}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1, 32000, 8192}, hgemm_policy::_8x512_8x16x16_1_true_},
        {{4, 32000, 8192}, hgemm_policy::_8x512_8x16x16_1_true_},
        {{32, 7168, 8192}, hgemm_policy::_64x512_64x16x16_1_true_},
        {{32, 8192, 7168}, hgemm_policy::_128x256_32x32x16_1_true_},
        {{32, 2048, 8192}, hgemm_policy::_128x128_16x32x64_1_true_},
        {{32, 8192, 2048}, hgemm_policy::_128x256_32x32x16_1_true_},
        {{1, 5504, 4096}, hgemm_policy::_128x512_64x32x16_1_true_},
        {{1, 4096, 5504}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{4, 5504, 4096}, hgemm_policy::_128x512_64x32x16_1_true_},
        {{4, 4096, 5504}, hgemm_policy::_128x64_16x16x64_1_true_},
        {{32, 5504, 4096}, hgemm_policy::_128x512_64x32x16_1_true_},
        {{32, 4096, 5504}, hgemm_policy::_128x256_64x16x16_1_true_},
        {{1024, 5504, 4096}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1024, 4096, 5504}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1, 50272, 7168}, hgemm_policy::_16x256_8x16x16_1_true_},
        {{4, 50272, 7168}, hgemm_policy::_16x256_8x16x16_1_true_},
        {{32, 3584, 7168}, hgemm_policy::_128x256_32x32x16_1_true_},
        {{32, 7168, 3584}, hgemm_policy::_64x512_64x16x16_1_true_},
        {{1024, 14336, 7168}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1024, 7168, 14336}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1024, 3584, 7168}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1024, 7168, 3584}, hgemm_policy::_256x256_32x64x32_1_true_},
        {{1, 125696, 5120}, hgemm_policy::_256x256_32x64x16_1_true_},
        {{1, 5120, 125696}, hgemm_policy::_32x128_8x16x32_1_true_},
        {{4, 125696, 5120}, hgemm_policy::_256x256_32x64x16_1_true_},
        {{4, 5120, 125696}, hgemm_policy::_32x128_8x16x32_1_true_},
        {{32, 128, 4096}, hgemm_policy::_8x64_8x16x32_8_true_},
        {{32, 4096, 128}, hgemm_policy::_32x128_8x16x32_1_true_},
        {{1024, 128, 4096}, hgemm_policy::_32x64_32x16x16_8_true_},
        {{1024, 4096, 128}, hgemm_policy::_256x256_64x32x16_1_true_},
        {{1, 128, 4096}, hgemm_policy::_8x64_8x16x32_8_true_},
        {{1, 4096, 128}, hgemm_policy::_8x64_8x16x16_4_true_},
        {{4, 128, 4096}, hgemm_policy::_8x64_8x16x32_8_true_},
        {{4, 4096, 128}, hgemm_policy::_8x64_8x16x16_4_true_},
};

// Selects a tuned HGEMM policy id for an (m, n, k) problem.
//
// Policies are only chosen here when B is row-major:
//   1. an exact shape match in hgemm_b_row_special_table wins;
//   2. otherwise, n == 4096 with m <= 128 maps to
//      _128x64_16x16x64_1_true_.
// Returns -1 when B is not row-major or when neither rule applies.
int hgemm_find_policy_id(
    const int m,
    const int n,
    const int k,
    const bool is_b_row_major) {
  if (!is_b_row_major) {
    return -1;
  }

  const auto tuned_id =
      hgemm_find_policy_id_(m, n, k, hgemm_b_row_special_table);
  if (tuned_id != -1) {
    return tuned_id;
  }

  // Heuristic fallback for shapes that are not listed in the table.
  const bool small_m_with_n_4096 = (m <= 128) && (n == 4096);
  return small_m_with_n_4096
      ? static_cast<int>(hgemm_policy::_128x64_16x16x64_1_true_)
      : -1;
}

// Selects a tuned HGEMM policy id for the QKV (m, n, k) problem.
//
// Policies are only chosen here when B is row-major:
//   1. an exact shape match in hgemm_qkv_b_row_special_table wins;
//   2. otherwise, n == 4096 with m <= 128 maps to _128x64_16x16x64_1_true_;
//   3. otherwise, m == 1028 with n == 14336 maps to
//      _256x256_64x32x16_1_true_.
// Returns -1 when B is not row-major or when none of these rules apply.
int hgemm_qkv_find_policy_id(
    const int m,
    const int n,
    const int k,
    const bool is_b_row_major) {
  if (!is_b_row_major) {
    return -1;
  }

  const auto tuned_id =
      hgemm_find_policy_id_(m, n, k, hgemm_qkv_b_row_special_table);
  if (tuned_id != -1) {
    return tuned_id;
  }

  // Heuristic fallbacks for shapes that are not listed in the table. The two
  // conditions are mutually exclusive (m <= 128 vs. m == 1028).
  if (m <= 128 && n == 4096) {
    return static_cast<int>(hgemm_policy::_128x64_16x16x64_1_true_);
  }
  if (n == 14336 && m == 1028) {
    return static_cast<int>(hgemm_policy::_256x256_64x32x16_1_true_);
  }
  return -1;
}

} // namespace torch_ipex::xpu::xetla::xehpc

#endif
